// Copyright 2023 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "xls/jit/block_jit.h"

#include <cstdint>
#include <cstring>
#include <iterator>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "llvm/include/llvm/IR/DataLayout.h"
#include "llvm/include/llvm/Support/Error.h"
#include "xls/codegen/block_inlining_pass.h"
#include "xls/codegen/codegen_options.h"
#include "xls/codegen/codegen_pass.h"
#include "xls/codegen/maybe_materialize_fifos_pass.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/interpreter/block_evaluator.h"
#include "xls/interpreter/evaluator_options.h"
#include "xls/interpreter/observer.h"
#include "xls/ir/block.h"
#include "xls/ir/block_elaboration.h"
#include "xls/ir/clone_package.h"
#include "xls/ir/elaboration.h"
#include "xls/ir/events.h"
#include "xls/ir/instantiation.h"
#include "xls/ir/node.h"
#include "xls/ir/nodes.h"
#include "xls/ir/package.h"
#include "xls/ir/register.h"
#include "xls/ir/type.h"
#include "xls/ir/value.h"
#include "xls/ir/value_utils.h"
#include "xls/ir/xls_ir_interface.pb.h"
#include "xls/jit/aot_compiler.h"
#include "xls/jit/aot_entrypoint.pb.h"
#include "xls/jit/function_base_jit.h"
#include "xls/jit/jit_buffer.h"
#include "xls/jit/jit_callbacks.h"
#include "xls/jit/jit_evaluator_options.h"
#include "xls/jit/jit_runtime.h"
#include "xls/jit/llvm_compiler.h"
#include "xls/jit/observer.h"
#include "xls/jit/orc_jit.h"
#include "xls/passes/pass_base.h"

namespace xls {

namespace {

// Codegen pass that never modifies the IR. It only verifies that, once the
// earlier passes of the pipeline have run, the top block has no remaining
// instantiations, because the JIT cannot execute them.
class CheckNoInstantiationsOnTop : public verilog::CodegenPass {
 public:
  CheckNoInstantiationsOnTop()
      : verilog::CodegenPass("check_no_instantiations",
                             "Check no instantiations remain on the top block") {
  }

 protected:
  absl::StatusOr<bool> RunInternal(
      Package* package, const verilog::CodegenPassOptions& options,
      PassResults* results, verilog::CodegenContext& context) const final {
    // Pure verification: report "unchanged" when the check passes.
    constexpr bool kIrChanged = false;
    XLS_RET_CHECK(context.top_block()->GetInstantiations().empty())
        << "Jit is unable to implement instantiations.";
    return kIrChanged;
  }
};

// Builds the compound codegen pass that rewrites a (cloned) package into a
// shape the JIT can run. It runs the FIFO materialization pass, then inlines
// instantiated blocks into the top block, and finally checks that no
// instantiations are left. The check must come after inlining.
std::unique_ptr<verilog::CodegenCompoundPass> PrepareForJitPassPipeline() {
  auto pipeline = std::make_unique<verilog::CodegenCompoundPass>(
      "prepare_for_jit", "Process the IR to make it compatible with the jit");
  pipeline->Add<verilog::MaybeMaterializeFifosPass>();
  pipeline->Add<verilog::BlockInliningPass>();
  pipeline->Add<CheckNoInstantiationsOnTop>();
  return pipeline;
}

// Helper override for BlockJitContinuation that just adds renames for the
// registers.
//
// Note this relies on the reg_rename_map being generated by the
// PrepareForJitPassPipeline to ensure only one rename was performed.
class ElaboratedBlockJitContinuation : public BlockJitContinuation {
 public:
  ElaboratedBlockJitContinuation(
      const BlockJit::InterfaceMetadata& metadata, BlockJit* jit,
      const JittedFunctionBase& jit_func,
      const absl::flat_hash_map<std::string, std::string>& reg_rename_map,
      const absl::flat_hash_map<std::string, Type*>& materialized_impl_regs,
      BlockEvaluator::OutputPortSampleTime sample_time)
      : BlockJitContinuation(metadata, jit, jit_func, sample_time),
        reg_rename_map_(reg_rename_map),
        materialized_impl_regs_(materialized_impl_regs) {}

  absl::flat_hash_map<std::string, Value> GetRegistersMap() const override {
    absl::flat_hash_map<std::string, Value> base =
        BlockJitContinuation::GetRegistersMap();
    for (const auto& [orig, rename] : reg_rename_map_) {
      if (auto node = base.extract(rename)) {
        node.key() = orig;
        base.insert(std::move(node));
      }
    }
    return base;
  }
  absl::flat_hash_map<std::string, int64_t> GetRegisterIndices()
      const override {
    auto base = BlockJitContinuation::GetRegisterIndices();
    for (const auto& [orig, rename] : reg_rename_map_) {
      if (auto node = base.extract(rename)) {
        node.key() = orig;
        base.insert(std::move(node));
      }
    }
    return base;
  }
  absl::Status SetRegisters(
      const absl::flat_hash_map<std::string, Value>& regs) override {
    absl::flat_hash_map<std::string, Value> translated_regs = regs;
    for (const auto& [orig, rename] : reg_rename_map_) {
      if (auto node = translated_regs.extract(orig)) {
        node.key() = rename;
        translated_regs.insert(std::move(node));
      }
    }
    // Registers inserted to implement elaboration don't have any analogue on
    // the original elaboration so they might not be in the register set. If
    // they aren't present give them a default value.
    for (const auto& [reg_name, ty] : materialized_impl_regs_) {
      XLS_RET_CHECK(reg_rename_map_.contains(reg_name));
      std::string_view renamed = reg_rename_map_.at(reg_name);
      if (!translated_regs.contains(renamed)) {
        translated_regs[renamed] = ZeroOfType(ty);
      }
    }
    XLS_RETURN_IF_ERROR(BlockJitContinuation::SetRegisters(translated_regs));
    return absl::OkStatus();
  }

 private:
  // Map from path::style::name -> real_flat_style_name
  const absl::flat_hash_map<std::string, std::string>& reg_rename_map_;
  // Registers added to prepare for jitting (like fifo implementation
  // registers). These need to be manually setup to zero.
  const absl::flat_hash_map<std::string, Type*>& materialized_impl_regs_;
};

// BlockJit used when an elaboration had to be flattened before jitting.
//
// It owns the register rename map and the set of registers added while
// preparing the package. It passes references to both into every
// ElaboratedBlockJitContinuation it creates, so those continuations must not
// outlive this object.
class ElaboratedBlockJit : public BlockJit {
 public:
  std::unique_ptr<BlockJitContinuation> NewContinuation(
      BlockEvaluator::OutputPortSampleTime sample_time) override {
    return std::make_unique<ElaboratedBlockJitContinuation>(
        /*metadata=*/metadata_, /*jit=*/this, /*jit_func=*/function_,
        /*reg_rename_map=*/reg_rename_map_,
        /*materialized_impl_regs=*/materialized_impl_regs_,
        /*sample_time=*/sample_time);
  }

 private:
  // Private: instances are created by BlockJit's factories (see the friend
  // declaration below).
  ElaboratedBlockJit(
      BlockJit::InterfaceMetadata metadata,
      absl::flat_hash_map<std::string, std::string> reg_rename_map,
      absl::flat_hash_map<std::string, Type*> materialized_impl_regs,
      std::unique_ptr<JitRuntime> runtime, std::unique_ptr<OrcJit> orc_jit,
      JittedFunctionBase function, bool support_observer_callbacks)
      : BlockJit(std::move(metadata), std::move(runtime), std::move(orc_jit),
                 std::move(function), support_observer_callbacks),
        reg_rename_map_(std::move(reg_rename_map)),
        materialized_impl_regs_(std::move(materialized_impl_regs)) {}

  // Original register name -> flattened register name.
  absl::flat_hash_map<std::string, std::string> reg_rename_map_;
  // Registers introduced while preparing for the JIT, with their types.
  absl::flat_hash_map<std::string, Type*> materialized_impl_regs_;

  friend class BlockJit;
};

// Result of cloning an elaboration's package and flattening its top block so
// the JIT can compile it.
struct ElaborationJitData {
  // The cloned package. `inlined_block` points into it, and the types in
  // `added_registers` belong to its type arena.
  std::unique_ptr<Package> cloned_package;
  // Top block of `cloned_package` after all instantiations have been inlined.
  // Defaults to null so a partially-filled instance never carries an
  // indeterminate pointer.
  Block* inlined_block = nullptr;
  // Register renames reported by the codegen context while flattening
  // (path::style::name -> real_flat_style_name).
  absl::flat_hash_map<std::string, std::string> renamed_registers;
  // Registers inserted while preparing the package (e.g. FIFO implementation
  // registers), with their types from `cloned_package`.
  absl::flat_hash_map<std::string, Type*> added_registers;
};

// Clones the package of `elab` and runs PrepareForJitPassPipeline on the clone
// so that its top block has every instantiation inlined.
//
// Returns the cloned package together with the flattened top block, the
// register renames, and the registers inserted by the preparation passes.
// Returns an error if the elaboration top is not an XLS block, if more than
// one block shares the top block's name, or if any preparation pass fails.
absl::StatusOr<ElaborationJitData> CloneElaborationPackage(
    const BlockElaboration& elab) {
  // Check before dereferencing: calling value() on an empty optional throws
  // std::bad_optional_access (or aborts without exceptions) instead of
  // producing this error.
  XLS_RET_CHECK(elab.top()->block())
      << "Top block of elaboration must be an XLS 'block' in order to use JIT";
  std::string_view top_name = elab.top()->block().value()->name();

  // The verifier should check this, but it's worth checking here too to avoid
  // confusing errors.
  XLS_RET_CHECK_EQ(absl::c_count_if(elab.package()->blocks(),
                                    [&](const std::unique_ptr<Block>& fb) {
                                      return fb->name() == top_name;
                                    }),
                   1)
      << "Multiple blocks have the same name as the top block. Unable to "
         "inline reliably.";
  XLS_ASSIGN_OR_RETURN(
      std::unique_ptr<Package> jit_package,
      ClonePackage(elab.package(),
                   absl::StrFormat("jit_clone_of_%s", elab.package()->name())));
  XLS_ASSIGN_OR_RETURN(Block * cloned_top, jit_package->GetBlock(top_name));
  verilog::CodegenContext codegen_context(cloned_top);
  PassResults results;
  const verilog::CodegenPassOptions opts{
      .codegen_options =
          verilog::CodegenOptions()
              .reset(FifoInstantiation::kResetPortName, /*asynchronous=*/false,
                     /*active_low=*/false, /*reset_data_path=*/false)
              // Force all FIFOs to be materialized.
              .set_fifo_module("")
              .set_nodata_fifo_module("")};
  XLS_RETURN_IF_ERROR(
      PrepareForJitPassPipeline()
          ->Run(jit_package.get(), opts, &results, codegen_context)
          .status());
  return ElaborationJitData{
      .cloned_package = std::move(jit_package),
      .inlined_block = codegen_context.top_block(),
      .renamed_registers = std::move(codegen_context.register_renames()),
      .added_registers = std::move(codegen_context.inserted_registers()),
  };
}

}  // namespace

// Builds interface metadata (names and types of the input ports, output ports
// and registers) from `block`. Every type is re-created in the metadata's own
// type manager, so the result does not reference `block`'s type arena.
/* static */ absl::StatusOr<BlockJit::InterfaceMetadata>
BlockJit::InterfaceMetadata::CreateFromBlock(Block* block) {
  InterfaceMetadata metadata;
  metadata.block_name = block->name();

  // Input ports, in the order GetInputPorts() yields them.
  metadata.input_port_names.reserve(block->GetInputPorts().size());
  metadata.input_port_types.reserve(block->GetInputPorts().size());
  for (InputPort* port : block->GetInputPorts()) {
    metadata.input_port_names.emplace_back(port->name());
    XLS_ASSIGN_OR_RETURN(
        Type * port_type,
        metadata.type_manager.MapTypeFromOtherArena(port->GetType()));
    metadata.input_port_types.push_back(port_type);
  }

  // Output ports, in the order GetOutputPorts() yields them.
  metadata.output_port_names.reserve(block->GetOutputPorts().size());
  metadata.output_port_types.reserve(block->GetOutputPorts().size());
  for (OutputPort* port : block->GetOutputPorts()) {
    metadata.output_port_names.emplace_back(port->name());
    XLS_ASSIGN_OR_RETURN(
        Type * port_type,
        metadata.type_manager.MapTypeFromOtherArena(port->port_type()));
    metadata.output_port_types.push_back(port_type);
  }

  // Registers, in the order GetRegisters() yields them.
  metadata.register_names.reserve(block->GetRegisters().size());
  metadata.register_types.reserve(block->GetRegisters().size());
  for (Register* reg : block->GetRegisters()) {
    metadata.register_names.emplace_back(reg->name());
    XLS_ASSIGN_OR_RETURN(
        Type * reg_type,
        metadata.type_manager.MapTypeFromOtherArena(reg->type()));
    metadata.register_types.push_back(reg_type);
  }

  return metadata;
}

// Builds interface metadata from the block interface recorded in an AOT
// entrypoint proto. Types are deserialized into the metadata's own type
// manager. Fails if `entrypoint` is not a BLOCK entrypoint or lacks block
// metadata.
/* static */ absl::StatusOr<BlockJit::InterfaceMetadata>
BlockJit::InterfaceMetadata::CreateFromAotEntrypoint(
    const AotEntrypointProto& entrypoint) {
  XLS_RET_CHECK_EQ(entrypoint.type(), AotEntrypointProto::BLOCK);
  XLS_RET_CHECK(entrypoint.has_block_metadata());
  const auto& iface = entrypoint.block_metadata().block_interface();

  InterfaceMetadata metadata;
  metadata.block_name = iface.base().name();

  metadata.input_port_names.reserve(iface.input_ports_size());
  metadata.input_port_types.reserve(iface.input_ports_size());
  for (const PackageInterfaceProto::NamedValue& port : iface.input_ports()) {
    metadata.input_port_names.push_back(port.name());
    XLS_ASSIGN_OR_RETURN(Type * port_type,
                         metadata.type_manager.GetTypeFromProto(port.type()));
    metadata.input_port_types.push_back(port_type);
  }

  metadata.output_port_names.reserve(iface.output_ports_size());
  metadata.output_port_types.reserve(iface.output_ports_size());
  for (const PackageInterfaceProto::NamedValue& port : iface.output_ports()) {
    metadata.output_port_names.push_back(port.name());
    XLS_ASSIGN_OR_RETURN(Type * port_type,
                         metadata.type_manager.GetTypeFromProto(port.type()));
    metadata.output_port_types.push_back(port_type);
  }

  metadata.register_names.reserve(iface.registers_size());
  metadata.register_types.reserve(iface.registers_size());
  for (const PackageInterfaceProto::NamedValue& reg : iface.registers()) {
    metadata.register_names.push_back(reg.name());
    XLS_ASSIGN_OR_RETURN(Type * reg_type,
                         metadata.type_manager.GetTypeFromProto(reg.type()));
    metadata.register_types.push_back(reg_type);
  }

  return metadata;
}

// Convenience overload: elaborates `block` and jits the resulting elaboration.
absl::StatusOr<std::unique_ptr<BlockJit>> BlockJit::Create(
    Block* block, bool support_observer_callbacks) {
  XLS_ASSIGN_OR_RETURN(BlockElaboration elaboration,
                       BlockElaboration::Elaborate(block));
  return BlockJit::Create(elaboration, support_observer_callbacks);
}

// Creates a BlockJit for the top of `elab`.
//
// If the top is a block with no instantiations, it is compiled directly.
// Otherwise the elaboration's package is cloned and flattened (see
// CloneElaborationPackage). In that case an ElaboratedBlockJit is returned,
// which translates between the flattened register names and the original
// ones.
absl::StatusOr<std::unique_ptr<BlockJit>> BlockJit::Create(
    const BlockElaboration& elab, bool support_observer_callbacks) {
  XLS_ASSIGN_OR_RETURN(
      std::unique_ptr<OrcJit> orc_jit,
      OrcJit::Create(
          LlvmCompiler::kDefaultOptLevel,
          /*include_observer_callbacks=*/support_observer_callbacks));
  XLS_ASSIGN_OR_RETURN(auto data_layout, orc_jit->CreateDataLayout());
  auto jit_runtime = std::make_unique<JitRuntime>(data_layout);

  // Fast path: nothing to inline, so compile the top block as-is.
  if (elab.top()->block() &&
      (*elab.top()->block())->GetInstantiations().empty()) {
    // The block just checked above; declared here so it is never left
    // uninitialized on the other path.
    Block* block = *elab.top()->block();
    XLS_ASSIGN_OR_RETURN(InterfaceMetadata metadata,
                         InterfaceMetadata::CreateFromBlock(block));
    XLS_ASSIGN_OR_RETURN(
        auto function,
        JittedFunctionBase::Build(block, *orc_jit, EvaluatorOptions()));
    return std::unique_ptr<BlockJit>(new BlockJit(
        std::move(metadata), std::move(jit_runtime), std::move(orc_jit),
        std::move(function), support_observer_callbacks));
  }

  XLS_ASSIGN_OR_RETURN(ElaborationJitData jit_data,
                       CloneElaborationPackage(elab));
  XLS_ASSIGN_OR_RETURN(JittedFunctionBase jit_entrypoint,
                       JittedFunctionBase::Build(jit_data.inlined_block,
                                                 *orc_jit, EvaluatorOptions()));
  XLS_ASSIGN_OR_RETURN(
      InterfaceMetadata metadata,
      InterfaceMetadata::CreateFromBlock(jit_data.inlined_block));

  // jit_data.added_registers has types that come from the cloned package, but
  // we're only going to keep the package in metadata. We need to map the types
  // to the package we're keeping.
  for (auto& [_, reg_type] : jit_data.added_registers) {
    XLS_ASSIGN_OR_RETURN(Type * mapped_type,
                         metadata.type_manager.MapTypeFromOtherArena(reg_type));
    reg_type = mapped_type;
  }
  return std::unique_ptr<BlockJit>(new ElaboratedBlockJit(
      std::move(metadata), std::move(jit_data.renamed_registers),
      std::move(jit_data.added_registers), std::move(jit_runtime),
      std::move(orc_jit), std::move(jit_entrypoint),
      support_observer_callbacks));
}

// Creates a BlockJit that runs an ahead-of-time compiled block.
//
// `entrypoint` describes the block interface, the register aliases and the
// registers added during preparation. `data_layout` is the LLVM data layout
// string to build the runtime with, and `func_ptr` is the compiled entry
// function. No OrcJit is created, and observer callbacks are disabled on this
// path.
/* static */ absl::StatusOr<std::unique_ptr<BlockJit>> BlockJit::CreateFromAot(
    const AotEntrypointProto& entrypoint, std::string_view data_layout,
    JitFunctionType func_ptr) {
  XLS_ASSIGN_OR_RETURN(
      JittedFunctionBase jfb,
      JittedFunctionBase::BuildFromAot(entrypoint, func_ptr,
                                       /*packed_entrypoint=*/std::nullopt));
  llvm::Expected<llvm::DataLayout> layout =
      llvm::DataLayout::parse(data_layout);
  if (!layout) {
    // A failed llvm::Expected still owns an llvm::Error that must be taken.
    // Destroying it unhandled aborts in builds with LLVM's ABI-breaking checks
    // enabled. Take it and report LLVM's diagnostic.
    return absl::InternalError(
        absl::StrFormat("bad layout: %s: %s", data_layout,
                        llvm::toString(layout.takeError())));
  }
  XLS_ASSIGN_OR_RETURN(InterfaceMetadata metadata,
                       InterfaceMetadata::CreateFromAotEntrypoint(entrypoint));
  absl::flat_hash_map<std::string, std::string> renamed_regs;
  XLS_RET_CHECK(entrypoint.has_block_metadata());
  const AotEntrypointProto::BlockMetadataProto& block_metadata_proto =
      entrypoint.block_metadata();
  renamed_regs.insert(block_metadata_proto.register_aliases().begin(),
                      block_metadata_proto.register_aliases().end());
  absl::flat_hash_map<std::string, Type*> added_regs;
  added_regs.reserve(block_metadata_proto.added_registers_size());
  for (const auto& [reg_name, reg_ty] :
       block_metadata_proto.added_registers()) {
    XLS_ASSIGN_OR_RETURN(added_regs[reg_name],
                         metadata.type_manager.GetTypeFromProto(reg_ty));
  }
  return std::unique_ptr<BlockJit>(new ElaboratedBlockJit(
      std::move(metadata), std::move(renamed_regs), std::move(added_regs),
      std::make_unique<JitRuntime>(*layout),
      /*orc_jit=*/nullptr, std::move(jfb),
      /*support_observer_callbacks=*/false));
}

// Compiles the top of `elab` ahead of time and returns the object code,
// together with a single block entrypoint, the data layout used, and the
// cloned and flattened package the entrypoint refers to.
/* static */ absl::StatusOr<JitObjectCode> BlockJit::CreateObjectCode(
    const BlockElaboration& elab, const JitEvaluatorOptions& jit_options) {
  XLS_ASSIGN_OR_RETURN(std::unique_ptr<AotCompiler> compiler,
                       AotCompiler::Create(jit_options));
  XLS_ASSIGN_OR_RETURN(llvm::DataLayout layout,
                       compiler->CreateDataLayout());

  // The package is cloned and prepared even when the top block has no
  // instantiations. Skipping the clone in that case is possible, but on this
  // ahead-of-time path it is simpler not to bother, and the clone is not
  // expected to be long-lived.
  XLS_ASSIGN_OR_RETURN(ElaborationJitData prepared,
                       CloneElaborationPackage(elab));
  XLS_ASSIGN_OR_RETURN(
      auto jitted,
      JittedFunctionBase::Build(prepared.inlined_block, *compiler,
                                EvaluatorOptions(), jit_options.symbol_salt()));
  XLS_ASSIGN_OR_RETURN(auto object_code, std::move(compiler)->GetObjectCode());

  return JitObjectCode{
      .object_code = std::move(object_code),
      .entrypoints =
          {
              FunctionEntrypoint{
                  .function = prepared.inlined_block,
                  .jit_info = std::move(jitted),
                  .register_aliases = std::move(prepared.renamed_registers),
                  .added_registers = std::move(prepared.added_registers),
              },
          },
      .data_layout = layout,
      .package = std::move(prepared.cloned_package),
  };
}

// Returns a new continuation with its own buffers. It is bound to this JIT's
// interface metadata and compiled function.
std::unique_ptr<BlockJitContinuation> BlockJit::NewContinuation(
    BlockEvaluator::OutputPortSampleTime sample_time) {
  // Built with `new` rather than std::make_unique, presumably because the
  // BlockJitContinuation constructor is not publicly accessible.
  BlockJitContinuation* fresh =
      new BlockJitContinuation(metadata_, this, function_, sample_time);
  return std::unique_ptr<BlockJitContinuation>(fresh);
}

// Makes sure each register's output buffer holds the value of the write that
// was actually enabled this cycle.
//
// The output argument set stores each register's first write in the register
// output buffers. Second and later writes go into the extra register write
// buffers. `callbacks_.active_register_writes` maps a register index to the
// write indices recorded for it during the run (presumably the writes whose
// enable fired). Index 0 is the first write, and index k > 0 is
// extra_register_write_pointers[k - 1]. If a register's only recorded write is
// an extra one, its value is copied into the register output buffer.
//
// Returns an internal error if more than one write of a register was recorded
// in the same cycle, or if a recorded write index is out of range.
absl::Status BlockJit::ReconcileMultipleRegisterWrites(
    BlockJitContinuation& continuation) {
  // Save away active register writes and clear it for next time.
  absl::flat_hash_map<int64_t, std::vector<int64_t>> active_register_writes =
      std::move(continuation.callbacks_.active_register_writes);
  continuation.callbacks_.active_register_writes.clear();

  absl::Span<uint8_t* const> register_output_pointers =
      continuation.output_arg_set().get_element_pointers().subspan(
          metadata_.OutputPortCount(), metadata_.RegisterCount());
  absl::Span<uint8_t* const> extra_register_write_pointers =
      continuation.output_arg_set().get_element_pointers().subspan(
          metadata_.OutputPortCount() + metadata_.RegisterCount());
  // Without extra write buffers there is nothing to move.
  if (extra_register_write_pointers.empty()) {
    return absl::OkStatus();
  }

  for (int64_t reg_no = 0; reg_no < metadata_.RegisterCount(); ++reg_no) {
    auto it = active_register_writes.find(reg_no);
    if (it == active_register_writes.end()) {
      continue;
    }
    const std::vector<int64_t>& activated_reg_writes = it->second;
    XLS_RET_CHECK_GE(activated_reg_writes.size(), 1);
    if (activated_reg_writes.size() > 1) {
      return absl::InternalError(absl::StrFormat(
          "Multiple writes of register `%s` activated in the same cycle",
          metadata_.register_names[reg_no]));
    }
    const int64_t write_index = activated_reg_writes.front();
    if (write_index == 0) {
      // The first register write (maybe only) activated. Its value is already
      // in the register output buffer.
      continue;
    }
    // Guard the indexing below against an out-of-range recorded write.
    XLS_RET_CHECK_GT(write_index, 0);
    XLS_RET_CHECK_LE(write_index,
                     static_cast<int64_t>(extra_register_write_pointers.size()));
    memcpy(register_output_pointers[reg_no],
           extra_register_write_pointers[write_index - 1],
           GetRegisterBufferMetadata()[reg_no].size);
  }
  return absl::OkStatus();
}

// Advances `continuation` by one clock cycle.
//
// The jitted function runs once to compute register writes. Events from that
// run go to the continuation. The writes are then reconciled and swapped to
// the read side. When outputs are sampled after the last clock edge, the
// function runs a second time with the new register values, writing into the
// dedicated after-last-clock output set. Events from that second run are
// discarded.
absl::Status BlockJit::RunOneCycle(BlockJitContinuation& continuation) {
  RuntimeObserver* observer = continuation.observer();
  if (observer != nullptr) {
    observer->Tick();
  }

  // First run: computes the register updates.
  function_.RunJittedFunction(
      continuation.input_arg_set(), continuation.output_arg_set(),
      continuation.temp_buffer_, &continuation.GetEvents(),
      /*instance_context=*/&continuation.callbacks_, runtime_.get(),
      /*continuation_point=*/0);
  XLS_RETURN_IF_ERROR(ReconcileMultipleRegisterWrites(continuation));

  // Promote this cycle's register writes to the read side.
  continuation.SwapRegisters();

  const bool sample_after_clock =
      continuation.sample_time() ==
      BlockEvaluator::OutputPortSampleTime::kAfterLastClock;
  if (sample_after_clock) {
    // Second run: recompute the output ports from the updated registers.
    InterpreterEvents discarded_events;
    function_.RunJittedFunction(continuation.input_arg_set(),
                                *continuation.after_last_clock_output_set_,
                                continuation.temp_buffer_, &discarded_events,
                                /*instance_context=*/&continuation.callbacks_,
                                runtime_.get(),
                                /*continuation_point=*/0);
  }
  return absl::OkStatus();
}

namespace {

// Concatenates the element pointers of each buffer in `buffers`, in order, and
// returns them as a single vector.
std::vector<uint8_t*> ComposeBuffers(absl::Span<JitBuffer* const> buffers) {
  // Size the result once so appending never reallocates.
  size_t total = 0;
  for (const JitBuffer* buffer : buffers) {
    total += buffer->pointers.size();
  }
  std::vector<uint8_t*> result;
  result.reserve(total);
  for (JitBuffer* buffer : buffers) {
    absl::c_copy(buffer->pointers, std::back_inserter(result));
  }
  return result;
}

}  // namespace

// Allocates this continuation's buffers for `jit_func` and connects them into
// the argument sets used when running the jitted block.
//
// Input-port and output-port buffers are shared by both argument-set pairs.
// Two zero-initialized register buffers are allocated. input_sets_[0] reads
// register_buffers_[0] while output_sets_[0] writes register_buffers_[1], and
// the pair at index 1 does the reverse. Switching which pair is active (see
// SwapRegisters, not shown here) therefore moves register values from one
// cycle to the next without copying.
//
// NOTE(review): several initializers below read previously-initialized members
// (sample_time_, input_port_buffers_, register_buffers_, ...). This is only
// correct if those members are declared in the header before the members that
// use them. Keep the declaration order in sync.
BlockJitContinuation::BlockJitContinuation(
    const BlockJit::InterfaceMetadata& metadata, BlockJit* jit,
    const JittedFunctionBase& jit_func,
    BlockEvaluator::OutputPortSampleTime sample_time)
    : metadata_(metadata),
      block_jit_(jit),
      sample_time_(sample_time),

      input_port_buffers_(
          AllocateAlignedBuffer(jit->GetInputPortBufferMetadata())),
      output_port_buffers_(
          AllocateAlignedBuffer(jit->GetOutputPortBufferMetadata())),
      register_buffers_({AllocateAlignedBuffer(jit->GetRegisterBufferMetadata(),
                                               /*zero=*/true),
                         AllocateAlignedBuffer(jit->GetRegisterBufferMetadata(),
                                               /*zero=*/true)}),
      // These register writes are those writes beyond the first register write
      // for a register.
      extra_register_write_buffers_(
          AllocateAlignedBuffer(jit->GetExtraRegisterWriteBufferMetadata())),

      // The layout of the input buffers is:
      //
      //   {...input ports..,
      //    ...register reads...}
      //
      // To enable copy-free register value reuse from output of one cycle to
      // input of the next cycle, the input register buffers of input_sets_[0]
      // alias the output register buffers of output_sets_[1] and vice versa.
      input_sets_(
          {JitArgumentSet(
               &jit_func,
               ComposeBuffers({&input_port_buffers_, &register_buffers_[0]}),
               /*is_inputs=*/true,
               /*is_outputs=*/false),
           JitArgumentSet(
               &jit_func,
               ComposeBuffers({&input_port_buffers_, &register_buffers_[1]}),
               /*is_inputs=*/true,
               /*is_outputs=*/false)}),

      // The layout of the output buffers is:
      //
      //   {...output ports..,
      //    ...first register writes...,
      //    ...extra register writes... }
      //
      // The "first" register writes are the set of first elements in
      // Block::GetRegisterWrites. The "extra" register writes are the second
      // and further register writes from Block::GetRegisterWrites. After
      // running the JIT-ed code, registers with multiple writes are reconciled
      // and the next register value is written into the "first" register write
      // buffer so these buffers can be used as input register values for the
      // next cycle.
      output_sets_(
          {JitArgumentSet(
               &jit_func,
               ComposeBuffers({&output_port_buffers_, &register_buffers_[1],
                               &extra_register_write_buffers_}),
               /*is_inputs=*/false,
               /*is_outputs=*/true),
           JitArgumentSet(
               &jit_func,
               ComposeBuffers({&output_port_buffers_, &register_buffers_[0],
                               &extra_register_write_buffers_}),
               /*is_inputs=*/false,
               /*is_outputs=*/true)}),

      // Separate output set for the second, post-clock run in
      // BlockJit::RunOneCycle. It is only allocated when outputs are sampled
      // after the last clock; otherwise it is null.
      after_last_clock_output_set_(
          sample_time_ == BlockEvaluator::OutputPortSampleTime::kAfterLastClock
              ? jit_func.CreateOutputBuffer()
              : nullptr),
      // Scratch buffer passed to every run of the jitted function.
      temp_buffer_(jit_func.CreateTempBuffer()),
      // Instance context handed to the jitted code. It also collects the
      // active register writes consumed by ReconcileMultipleRegisterWrites.
      callbacks_(InstanceContext::CreateForBlock()) {}

// Packs `values` (one per input port, in metadata order) into the input port
// buffers. Every value is type-checked against its port before any buffer is
// written.
absl::Status BlockJitContinuation::SetInputPorts(
    absl::Span<const Value> values) {
  XLS_RET_CHECK_EQ(metadata_.InputPortCount(), values.size());
  int64_t i = 0;
  for (auto it = values.cbegin(); it != values.cend(); ++it) {
    Type* ip_type = metadata_.input_port_types[i];
    XLS_RET_CHECK(ValueConformsToType(*it, ip_type))
        << "input port " << metadata_.input_port_names[i]
        << " cannot be set to value of " << *it
        << " due to type mismatch with input port type of "
        << ip_type->ToString();
    ++i;
  }
  return block_jit_->runtime()->PackArgs(values, metadata_.input_port_types,
                                         input_port_pointers());
}

// Copies raw, already-laid-out input values into the input port buffers.
// `inputs[i]` must point to at least GetInputPortBufferMetadata()[i].size
// bytes. No type checking is performed.
absl::Status BlockJitContinuation::SetInputPorts(
    absl::Span<const uint8_t* const> inputs) {
  XLS_RET_CHECK_EQ(metadata_.InputPortCount(), inputs.size());
  // TODO(allight): This is a lot of copying. We could do this more efficiently
  const auto& buffer_metadata = block_jit_->GetInputPortBufferMetadata();
  // Loop over the (already checked equal) port count with int64_t, consistent
  // with the Value overload and avoiding a signed/unsigned comparison.
  for (int64_t i = 0; i < metadata_.InputPortCount(); ++i) {
    memcpy(input_port_pointers()[i], inputs[i], buffer_metadata[i].size);
  }
  return absl::OkStatus();
}

// Sets input ports by name. `inputs` must name every input port exactly once.
// An unknown name or a missing port yields InvalidArgument.
absl::Status BlockJitContinuation::SetInputPorts(
    const absl::flat_hash_map<std::string, Value>& inputs) {
  std::vector<Value> ordered(metadata_.InputPortCount());
  const absl::flat_hash_map<std::string, int64_t> index_of =
      GetInputPortIndices();
  for (const auto& [port_name, port_value] : inputs) {
    auto found = index_of.find(port_name);
    if (found == index_of.end()) {
      return absl::InvalidArgumentError(
          absl::StrFormat("Block has no input port '%s'", port_name));
    }
    ordered[found->second] = port_value;
  }

  // Every supplied name is valid at this point, so a size mismatch means some
  // ports were left unset. List them.
  if (inputs.size() != metadata_.InputPortCount()) {
    std::ostringstream missing;
    for (const auto& ip_name : metadata_.input_port_names) {
      if (inputs.contains(ip_name)) {
        continue;
      }
      missing << "\n\tMissing input for port '" << ip_name << "'";
    }
    return absl::InvalidArgumentError(
        absl::StrFormat("Expected %d input port values but only got %d:%s",
                        ordered.size(), inputs.size(), missing.str()));
  }
  return SetInputPorts(ordered);
}

// Packs `values` (one per register, in metadata order) into the register
// buffers. Every value is type-checked before any buffer is written.
absl::Status BlockJitContinuation::SetRegisters(
    absl::Span<const Value> values) {
  XLS_RET_CHECK_EQ(metadata_.RegisterCount(), values.size());
  int64_t i = 0;
  for (auto it = values.cbegin(); it != values.cend(); ++it) {
    XLS_RET_CHECK(ValueConformsToType(*it, metadata_.register_types[i]))
        << "register " << metadata_.register_names[i]
        << " cannot be set to value of " << *it
        << " due to type mismatch with register type of "
        << metadata_.register_types[i]->ToString();
    ++i;
  }
  return block_jit_->runtime()->PackArgs(values, metadata_.register_types,
                                         register_pointers());
}

// Copies raw, already-laid-out register values into the register buffers.
// `regs[i]` must point to at least GetRegisterBufferMetadata()[i].size bytes.
// No type checking is performed.
absl::Status BlockJitContinuation::SetRegisters(
    absl::Span<const uint8_t* const> regs) {
  XLS_RET_CHECK_EQ(metadata_.RegisterCount(), regs.size());
  // TODO(allight): This is a lot of copying. We could do this more efficiently
  const auto& buffer_metadata = block_jit_->GetRegisterBufferMetadata();
  // Loop over the (already checked equal) register count with int64_t,
  // consistent with the Value overload and avoiding a signed/unsigned
  // comparison.
  for (int64_t i = 0; i < metadata_.RegisterCount(); ++i) {
    memcpy(register_pointers()[i], regs[i], buffer_metadata[i].size);
  }
  return absl::OkStatus();
}

// Sets registers by name. `regs` must name every register exactly once. An
// unknown name or a missing register yields InvalidArgument.
absl::Status BlockJitContinuation::SetRegisters(
    const absl::flat_hash_map<std::string, Value>& regs) {
  // Deliberately uses this class's own index map rather than the virtual one.
  // A subclass (ElaboratedBlockJitContinuation) overrides GetRegisterIndices to
  // report original names, but translates names itself before delegating here.
  const absl::flat_hash_map<std::string, int64_t> index_of =
      BlockJitContinuation::GetRegisterIndices();
  std::vector<Value> ordered(index_of.size());
  for (const auto& [reg_name, reg_value] : regs) {
    auto found = index_of.find(reg_name);
    if (found == index_of.end()) {
      return absl::InvalidArgumentError(
          absl::StrFormat("Block has no register '%s'", reg_name));
    }
    ordered[found->second] = reg_value;
  }

  // All supplied names are valid here, so a size mismatch means some
  // registers were left unset. List them.
  if (regs.size() != metadata_.RegisterCount()) {
    std::ostringstream missing;
    for (const auto& name : metadata_.register_names) {
      if (regs.contains(name)) {
        continue;
      }
      missing << "\n\tMissing value for register '" << name << "'";
    }
    return absl::InvalidArgumentError(
        absl::StrFormat("Expected %d register values but only got %d:%s",
                        index_of.size(), regs.size(), missing.str()));
  }
  return SetRegisters(ordered);
}

std::vector<Value> BlockJitContinuation::GetOutputPorts() const {
  std::vector<Value> result;
  result.reserve(output_port_pointers().size());
  int i = 0;
  for (auto ptr : output_port_pointers()) {
    result.push_back(block_jit_->runtime()->UnpackBuffer(
        ptr, metadata_.output_port_types[i++]));
  }
  return result;
}

// Maps each input port name to its position in metadata_.input_port_names.
absl::flat_hash_map<std::string, int64_t>
BlockJitContinuation::GetInputPortIndices() const {
  absl::flat_hash_map<std::string, int64_t> index_by_name;
  int64_t position = 0;
  for (const auto& port_name : metadata_.input_port_names) {
    index_by_name[port_name] = position;
    ++position;
  }
  return index_by_name;
}

// Maps each output port name to its position in metadata_.output_port_names.
absl::flat_hash_map<std::string, int64_t>
BlockJitContinuation::GetOutputPortIndices() const {
  absl::flat_hash_map<std::string, int64_t> index_by_name;
  int64_t position = 0;
  for (const auto& port_name : metadata_.output_port_names) {
    index_by_name[port_name] = position;
    ++position;
  }
  return index_by_name;
}

// Maps each register name to its position in metadata_.register_names.
absl::flat_hash_map<std::string, int64_t>
BlockJitContinuation::GetRegisterIndices() const {
  absl::flat_hash_map<std::string, int64_t> index_by_name;
  int64_t position = 0;
  for (const auto& reg_name : metadata_.register_names) {
    index_by_name[reg_name] = position;
    ++position;
  }
  return index_by_name;
}

// Returns the current output port values keyed by port name.
absl::flat_hash_map<std::string, Value>
BlockJitContinuation::GetOutputPortsMap() const {
  absl::flat_hash_map<std::string, Value> by_name;
  by_name.reserve(output_port_pointers().size());
  const std::vector<Value> port_values = GetOutputPorts();
  for (const auto& [port_name, index] : GetOutputPortIndices()) {
    by_name[port_name] = port_values[index];
  }
  return by_name;
}

std::vector<Value> BlockJitContinuation::GetRegisters() const {
  std::vector<Value> result;
  result.reserve(register_pointers().size());
  int i = 0;
  for (auto ptr : register_pointers()) {
    result.push_back(block_jit_->runtime()->UnpackBuffer(
        ptr, metadata_.register_types[i++]));
  }
  return result;
}

// Returns the current register values keyed by the names in
// metadata_.register_names.
absl::flat_hash_map<std::string, Value> BlockJitContinuation::GetRegistersMap()
    const {
  absl::flat_hash_map<std::string, Value> by_name;
  by_name.reserve(register_pointers().size());
  const std::vector<Value> reg_values = GetRegisters();
  // Qualified (non-virtual) call: subclasses that rename registers build on
  // this map and expect it to use the untranslated names.
  for (const auto& [reg_name, index] :
       BlockJitContinuation::GetRegisterIndices()) {
    by_name[reg_name] = reg_values[index];
  }
  return by_name;
}

namespace {
// Helper adapter to implement the interpreter-focused block-continuation api
// used by eval_proc_main. This holds live all the values needed to run the
// block-jit.
class BlockContinuationJitWrapper final : public BlockContinuation {
 public:
  BlockContinuationJitWrapper(std::unique_ptr<BlockJitContinuation>&& cont,
                              std::unique_ptr<BlockJit>&& jit)
      : continuation_(std::move(cont)), jit_(std::move(jit)) {}
  JitRuntime* runtime() const { return jit_->runtime(); }
  const absl::flat_hash_map<std::string, Value>& output_ports() final {
    if (!temporary_outputs_) {
      temporary_outputs_.emplace(continuation_->GetOutputPortsMap());
    }
    return *temporary_outputs_;
  }
  BlockEvaluator::OutputPortSampleTime sample_time() const final {
    return continuation_->sample_time();
  }
  const absl::flat_hash_map<std::string, Value>& registers() final {
    if (!temporary_regs_) {
      temporary_regs_.emplace(continuation_->GetRegistersMap());
    }
    return *temporary_regs_;
  }
  const InterpreterEvents& events() final { return continuation_->GetEvents(); }
  absl::Status RunOneCycle(
      const absl::flat_hash_map<std::string, Value>& inputs) final {
    temporary_outputs_.reset();
    temporary_regs_.reset();
    continuation_->ClearEvents();
    XLS_RETURN_IF_ERROR(continuation_->SetInputPorts(inputs));
    return jit_->RunOneCycle(*continuation_);
  }
  absl::Status SetRegisters(
      const absl::flat_hash_map<std::string, Value>& regs) final {
    temporary_regs_.reset();
    return continuation_->SetRegisters(regs);
  }

  void ClearObserver() override {
    continuation_->ClearObserver();
    eval_observer_.reset();
  }
  absl::Status SetObserver(EvaluationObserver* obs) override {
    ClearObserver();
    std::optional<RuntimeObserver*> run = obs->AsRawObserver();
    if (run) {
      return continuation_->SetObserver(*run);
    }
    eval_observer_.emplace(
        obs,
        [](int64_t ptr) -> Node* {
          return reinterpret_cast<Node*>(static_cast<intptr_t>(ptr));
        },
        jit_->runtime());
    return continuation_->SetObserver(&eval_observer_.value());
  }

 private:
  std::unique_ptr<BlockJitContinuation> continuation_;
  std::unique_ptr<BlockJit> jit_;
  // Holder for the data we return out of output_ports so that we can reduce
  // copying.
  std::optional<absl::flat_hash_map<std::string, Value>> temporary_outputs_;
  // Holder for the data we return out of registers so that we can reduce
  // copying.
  std::optional<absl::flat_hash_map<std::string, Value>> temporary_regs_;

  // Node value observer adapter from jit api to Value.
  std::optional<RuntimeEvaluationObserverAdapter> eval_observer_;
};
}  // namespace

absl::StatusOr<std::unique_ptr<BlockContinuation>>
JitBlockEvaluator::MakeNewContinuation(
    BlockElaboration&& elaboration,
    const absl::flat_hash_map<std::string, Value>& initial_registers,
    BlockEvaluator::OutputPortSampleTime sample_time) const {
  // Compile the elaborated block. Observer callbacks are only compiled in when
  // this evaluator was configured to support them.
  XLS_ASSIGN_OR_RETURN(
      auto compiled,
      BlockJit::Create(elaboration,
                       /*support_observer_callbacks=*/supports_observer_));

  // Start a fresh continuation and seed its register state.
  auto continuation = compiled->NewContinuation(sample_time);
  XLS_RETURN_IF_ERROR(continuation->SetRegisters(initial_registers));

  // The wrapper takes ownership of both the continuation and the JIT.
  return std::make_unique<BlockContinuationJitWrapper>(std::move(continuation),
                                                       std::move(compiled));
}

absl::StatusOr<JitRuntime*> JitBlockEvaluator::GetRuntime(
    BlockContinuation* cont) const {
  // Only continuations produced by this evaluator wrap a JIT. Anything else
  // (including a null `cont`) has no runtime to hand back.
  if (auto* jit_wrapper = dynamic_cast<BlockContinuationJitWrapper*>(cont);
      jit_wrapper != nullptr) {
    return jit_wrapper->runtime();
  }
  return absl::InvalidArgumentError("Not a jit continuation");
}

}  // namespace xls
